{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Multi-Objective NAS with Ax\n\n**Authors:** [David Eriksson](https://github.com/dme65),\n[Max Balandat](https://github.com/Balandat),\nand the Adaptive Experimentation team at Meta.\n\nIn this tutorial, we show how to use [Ax](https://ax.dev/) to run\nmulti-objective neural architecture search (NAS) for a simple neural\nnetwork model on the popular MNIST dataset. While the underlying\nmethodology would typically be used for more complicated models and\nlarger datasets, we opt for a tutorial that is easily runnable\nend-to-end on a laptop in under 20 minutes.\n\nIn many NAS applications, there is a natural tradeoff between multiple\nobjectives of interest. For instance, when deploying models on-device,\nwe may want to maximize model performance (for example, accuracy) while\nsimultaneously minimizing competing metrics like power consumption,\ninference latency, or model size in order to satisfy deployment\nconstraints. Often, we may be able to substantially reduce the computational\nrequirements or prediction latency of a model by accepting slightly lower\nmodel performance. Principled methods for exploring such tradeoffs\nefficiently are key enablers of scalable and sustainable AI, and have\nmany successful applications at Meta; see, for instance, our\n[case study](https://research.facebook.com/blog/2021/07/optimizing-model-accuracy-and-latency-using-bayesian-multi-objective-neural-architecture-search/)\non a Natural Language Understanding model.\n\nIn our example here, we will tune the widths of two hidden layers,\nthe learning rate, the dropout probability, the batch size, and the\nnumber of training epochs. The goal is to trade off performance\n(accuracy on the validation set) and model size (the number of\nmodel parameters).\n\nThis tutorial makes use of the following PyTorch libraries:\n\n- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) (for specifying the model and training loop)\n- [TorchX](https://github.com/pytorch/torchx) (for running training jobs remotely / asynchronously)\n- [BoTorch](https://github.com/pytorch/botorch) (the Bayesian Optimization library powering Ax's algorithms)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Defining the TorchX App\n\nOur goal is to optimize the PyTorch Lightning training job defined in\n[mnist_train_nas.py](https://github.com/pytorch/tutorials/tree/master/intermediate_source/mnist_train_nas.py).\nTo do this using TorchX, we write a helper function that takes in\nthe values of the architecture and hyperparameters of the training\njob and creates a [TorchX AppDef](https://pytorch.org/torchx/latest/basics.html)\nwith the appropriate settings.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\n\nimport torchx\n\nfrom torchx import specs\nfrom torchx.components import utils\n\n\ndef trainer(\n    log_path: str,\n    hidden_size_1: int,\n    hidden_size_2: int,\n    learning_rate: float,\n    epochs: int,\n    dropout: float,\n    batch_size: int,\n    trial_idx: int = -1,\n) -> specs.AppDef:\n\n    # define the log path so we can pass it to the TorchX AppDef\n    if trial_idx >= 0:\n        log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()\n\n    return utils.python(\n        # command line args to the training script\n        \"--log_path\",\n        log_path,\n        \"--hidden_size_1\",\n        str(hidden_size_1),\n        \"--hidden_size_2\",\n        str(hidden_size_2),\n        \"--learning_rate\",\n        str(learning_rate),\n        \"--epochs\",\n        str(epochs),\n        \"--dropout\",\n        str(dropout),\n        \"--batch_size\",\n        str(batch_size),\n        # other config options\n        name=\"trainer\",\n        script=\"mnist_train_nas.py\",\n        image=torchx.version.TORCHX_IMAGE,\n    )"
      ]
    },
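    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The ``trainer()`` helper only assembles the command-line flags for ``mnist_train_nas.py``; TorchX takes care of launching the job. To inspect what a given configuration would pass to the script without involving TorchX, the sketch below uses a hypothetical helper, ``trainer_args`` (not part of TorchX or Ax), that mirrors the flag names and the per-trial log directory used in ``trainer()``. The log path ``/tmp/nas_logs`` is made up for illustration, and the sketch assumes a POSIX file system."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\n\n\ndef trainer_args(log_path, trial_idx=-1, **hparams):\n    # Hypothetical helper: mirrors the CLI flags that trainer() passes to the script\n    if trial_idx >= 0:\n        # Same convention as trainer(): one log subdirectory per trial index\n        log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()\n    args = [\"--log_path\", log_path]\n    for name, value in hparams.items():\n        args += [f\"--{name}\", str(value)]\n    return args\n\n\nexample_args = trainer_args(\n    \"/tmp/nas_logs\",\n    trial_idx=3,\n    hidden_size_1=64,\n    hidden_size_2=32,\n    learning_rate=1e-3,\n    epochs=2,\n    dropout=0.2,\n    batch_size=64,\n)\nprint(example_args)"
      ]
    },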
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setting up the Runner\n\nAx\u2019s [Runner](https://ax.dev/api/core.html#ax.core.runner.Runner)\nabstraction allows writing interfaces to various backends.\nAx already comes with a Runner for TorchX, so we just need to\nconfigure it. For the purpose of this tutorial, we run jobs locally\nin a fully asynchronous fashion.\n\nTo launch them on a cluster instead, you can specify a\ndifferent TorchX scheduler and adjust the configuration appropriately.\nFor example, if you have a Kubernetes cluster, you just need to change the\nscheduler from ``local_cwd`` to ``kubernetes``.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tempfile\nfrom ax.runners.torchx import TorchXRunner\n\n# Make a temporary dir to log our results into\nlog_dir = tempfile.mkdtemp()\n\nax_runner = TorchXRunner(\n    tracker_base=\"/tmp/\",\n    component=trainer,\n    # NOTE: To launch this job on a cluster instead of locally you can\n    # specify a different scheduler and adjust args appropriately.\n    scheduler=\"local_cwd\",\n    component_const_params={\"log_path\": log_dir},\n    cfg={},\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setting up the SearchSpace\n\nFirst, we define our search space. Ax supports both range parameters\nof type integer and float as well as choice parameters, which can have\nnon-numerical types such as strings.\nWe will tune the hidden sizes, learning rate, dropout, and number of\nepochs as range parameters and tune the batch size as an ordered choice\nparameter to ensure that it is a power of 2.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ax.core import (\n    ChoiceParameter,\n    ParameterType,\n    RangeParameter,\n    SearchSpace,\n)\n\nparameters = [\n    # NOTE: In a real-world setting, hidden_size_1 and hidden_size_2\n    # should probably be powers of 2, but in our simple example this\n    # would mean that num_params can't take on that many values, which\n    # in turn makes the Pareto frontier look pretty weird.\n    RangeParameter(\n        name=\"hidden_size_1\",\n        lower=16,\n        upper=128,\n        parameter_type=ParameterType.INT,\n        log_scale=True,\n    ),\n    RangeParameter(\n        name=\"hidden_size_2\",\n        lower=16,\n        upper=128,\n        parameter_type=ParameterType.INT,\n        log_scale=True,\n    ),\n    RangeParameter(\n        name=\"learning_rate\",\n        lower=1e-4,\n        upper=1e-2,\n        parameter_type=ParameterType.FLOAT,\n        log_scale=True,\n    ),\n    RangeParameter(\n        name=\"epochs\",\n        lower=1,\n        upper=4,\n        parameter_type=ParameterType.INT,\n    ),\n    RangeParameter(\n        name=\"dropout\",\n        lower=0.0,\n        upper=0.5,\n        parameter_type=ParameterType.FLOAT,\n    ),\n    ChoiceParameter(  # NOTE: ChoiceParameters don't require log-scale\n        name=\"batch_size\",\n        values=[32, 64, 128, 256],\n        parameter_type=ParameterType.INT,\n        is_ordered=True,\n        sort_values=True,\n    ),\n]\n\nsearch_space = SearchSpace(\n    parameters=parameters,\n    # NOTE: In practice, it may make sense to add a constraint\n    # hidden_size_2 <= hidden_size_1\n    parameter_constraints=[],\n)"
      ]
    },
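    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Setting ``log_scale=True`` tells Ax to treat the parameter on a logarithmic scale, so that, for example, hidden sizes from 16 to 32 get as much weight in the search as those from 64 to 128. The sketch below illustrates the idea with plain log-uniform sampling followed by rounding for integer parameters. ``sample_log_uniform`` is a hypothetical helper for illustration only; it is not how Ax implements its parameter transforms."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import math\nimport random\n\n\ndef sample_log_uniform(lower, upper, is_int=False, rng=random):\n    # Hypothetical helper: draw uniformly in log space, then map back\n    value = math.exp(rng.uniform(math.log(lower), math.log(upper)))\n    if is_int:\n        value = round(value)\n    # Clip so that rounding or floating-point error cannot leave [lower, upper]\n    return min(max(value, lower), upper)\n\n\nrng = random.Random(0)\nhidden_sizes = [sample_log_uniform(16, 128, is_int=True, rng=rng) for _ in range(1000)]\nlearning_rates = [sample_log_uniform(1e-4, 1e-2, rng=rng) for _ in range(1000)]\n\n# Each factor-of-2 band of [16, 128] should receive roughly a third of the draws\nshare_small = sum(h < 32 for h in hidden_sizes) / len(hidden_sizes)\nshare_large = sum(h >= 64 for h in hidden_sizes) / len(hidden_sizes)\nprint(f\"share in [16, 32): {share_small:.2f}, share in [64, 128]: {share_large:.2f}\")"
      ]
    },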
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setting up Metrics\n\nAx has the concept of a [Metric](https://ax.dev/api/core.html#metric)\nthat defines properties of outcomes and how observations are obtained\nfor these outcomes. This allows, for example, encoding how data is fetched from\nsome distributed execution backend and post-processed before being\npassed as input to Ax.\n\nIn this tutorial, we will use\n[multi-objective optimization](https://ax.dev/tutorials/multiobjective_optimization.html)\nwith the goal of maximizing the validation accuracy and minimizing\nthe number of model parameters. The latter represents a simple proxy\nof model latency, which is hard to estimate accurately for small ML\nmodels (in an actual application we would benchmark the latency while\nrunning the model on-device).\n\nIn our example, TorchX will run the training jobs in a fully asynchronous\nfashion locally and write the results to the ``log_dir`` based on the trial\nindex (see the ``trainer()`` function above). We will define a metric\nclass that is aware of that logging directory. By subclassing\n[TensorboardCurveMetric](https://ax.dev/api/metrics.html?highlight=tensorboardcurvemetric#ax.metrics.tensorboard.TensorboardCurveMetric),\nwe get the logic to read and parse the TensorBoard logs for free.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\n\nfrom ax.metrics.tensorboard import TensorboardCurveMetric\n\n\nclass MyTensorboardMetric(TensorboardCurveMetric):\n\n    # NOTE: We need to tell the new Tensorboard metric how to get the id /\n    # file handle for the TensorBoard logs of a trial. In this case, our\n    # convention is that each trial logs to its own subdirectory of the\n    # pre-specified log dir, named after the trial index (this must match\n    # the log path that trainer() passes to the training script).\n    @classmethod\n    def get_ids_from_trials(cls, trials):\n        return {\n            trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix()\n            for trial in trials\n        }\n\n    # This indicates whether the metric is queryable while the trial is\n    # still running. We don't use this in the current tutorial, but Ax\n    # utilizes this to implement trial-level early-stopping functionality.\n    @classmethod\n    def is_available_while_running(cls):\n        return False"
      ]
    },
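    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The metric only finds results if it looks where ``trainer()`` wrote them, so the two path conventions must agree. The sketch below checks this with plain ``pathlib``; ``FakeTrial`` and the two helper functions are hypothetical and simply restate the path logic of ``get_ids_from_trials`` and ``trainer()``. Note that ``trainer()`` calls ``.absolute()`` while the metric does not, so the paths only match because ``tempfile.mkdtemp()`` returns an absolute path."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tempfile\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\n\n@dataclass\nclass FakeTrial:\n    # Hypothetical stand-in for an Ax trial: only the index is read here\n    index: int\n\n\ndef metric_side_path(base_dir, trial):\n    # Path convention of MyTensorboardMetric.get_ids_from_trials\n    return Path(base_dir).joinpath(str(trial.index)).as_posix()\n\n\ndef trainer_side_path(base_dir, trial_idx):\n    # Path convention of trainer() when it builds --log_path\n    return Path(base_dir).joinpath(str(trial_idx)).absolute().as_posix()\n\n\ndemo_dir = tempfile.mkdtemp()  # an absolute path, like log_dir above\nfake_trials = [FakeTrial(i) for i in range(3)]\nids = {t.index: metric_side_path(demo_dir, t) for t in fake_trials}\nmismatches = [\n    t.index for t in fake_trials if ids[t.index] != trainer_side_path(demo_dir, t.index)\n]\nprint(ids)\nprint(\"mismatches:\", mismatches)"
      ]
    },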
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we can instantiate the metrics for accuracy and the number of\nmodel parameters. Here `curve_name` is the name of the metric in the\nTensorBoard logs, while `name` is the metric name used internally\nby Ax. We also specify `lower_is_better` to indicate the favorable\ndirection of the two metrics.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "val_acc = MyTensorboardMetric(\n    name=\"val_acc\",\n    curve_name=\"val_acc\",\n    lower_is_better=False,\n)\nmodel_num_params = MyTensorboardMetric(\n    name=\"num_params\",\n    curve_name=\"num_params\",\n    lower_is_better=True,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setting up the OptimizationConfig\n\nThe way to tell Ax what it should optimize is by means of an\n[OptimizationConfig](https://ax.dev/api/core.html#module-ax.core.optimization_config).\nHere we use a ``MultiObjectiveOptimizationConfig`` as we will\nbe performing multi-objective optimization.\n\nAdditionally, Ax supports placing constraints on the different\nmetrics by specifying objective thresholds, which bound the region\nof interest in the outcome space that we want to explore. For this\nexample, we will constrain the validation accuracy to be at least\n0.94 (94%) and the number of model parameters to be at most 80,000.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ax.core import MultiObjective, Objective, ObjectiveThreshold\nfrom ax.core.optimization_config import MultiObjectiveOptimizationConfig\n\n\nopt_config = MultiObjectiveOptimizationConfig(\n    objective=MultiObjective(\n        objectives=[\n            Objective(metric=val_acc, minimize=False),\n            Objective(metric=model_num_params, minimize=True),\n        ],\n    ),\n    objective_thresholds=[\n        ObjectiveThreshold(metric=val_acc, bound=0.94, relative=False),\n        ObjectiveThreshold(metric=model_num_params, bound=80_000, relative=False),\n    ],\n)"
      ]
    },
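    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Objective thresholds bound the region of the outcome space we care about: an outcome is only of interest if the validation accuracy is at least 0.94 and the model has at most 80,000 parameters. Within that region, the optimization looks for non-dominated tradeoffs. The hypothetical helpers below spell out both checks in plain Python for the two metrics of this tutorial; they are not part of Ax, and the outcome values are made up for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "VAL_ACC_MIN = 0.94  # objective threshold for val_acc (maximized)\nNUM_PARAMS_MAX = 80_000  # objective threshold for num_params (minimized)\n\n\ndef within_thresholds(val_acc, num_params):\n    # Hypothetical helper: is the outcome inside the region bounded by the thresholds?\n    return val_acc >= VAL_ACC_MIN and num_params <= NUM_PARAMS_MAX\n\n\ndef dominates(a, b):\n    # Hypothetical helper; a and b are (val_acc, num_params) pairs.\n    # a dominates b if it is at least as good in both objectives\n    # and strictly better in at least one of them.\n    at_least_as_good = a[0] >= b[0] and a[1] <= b[1]\n    strictly_better = a[0] > b[0] or a[1] < b[1]\n    return at_least_as_good and strictly_better\n\n\n# Made-up outcomes, for illustration only\nsmall_model = (0.950, 30_000)\nlarge_model = (0.970, 75_000)\ntoo_large = (0.980, 120_000)\nworse_model = (0.945, 40_000)\n\nprint(within_thresholds(*small_model), within_thresholds(*too_large))\nprint(dominates(small_model, worse_model), dominates(large_model, small_model))"
      ]
    },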
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Creating the Ax Experiment\n\nIn Ax, the [Experiment](https://ax.dev/api/core.html#ax.core.experiment.Experiment)\nobject stores all the information about the problem\nsetup.\n\n.. tip::\n  ``Experiment`` objects can be serialized to JSON or stored to a\n  database backend such as MySQL in order to persist and be available\n  to load on different machines. See the [Ax Docs](https://ax.dev/docs/storage.html)\n  on the storage functionality for details.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ax.core import Experiment\n\nexperiment = Experiment(\n    name=\"torchx_mnist\",\n    search_space=search_space,\n    optimization_config=opt_config,\n    runner=ax_runner,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Choosing the GenerationStrategy\n\nA [GenerationStrategy](https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStrategy)\nis the abstract representation of how we would like to perform the\noptimization. While this can be customized (if you\u2019d like to do so, see\n[this tutorial](https://ax.dev/tutorials/generation_strategy.html)),\nin most cases Ax can automatically determine an appropriate strategy\nbased on the search space, optimization config, and the total number\nof trials we want to run.\n\nTypically, Ax chooses to evaluate a number of quasi-random configurations\nbefore starting a model-based Bayesian Optimization strategy.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ax.modelbridge.dispatch_utils import choose_generation_strategy\n\ntotal_trials = 48  # total evaluation budget\n\ngs = choose_generation_strategy(\n    search_space=experiment.search_space,\n    optimization_config=experiment.optimization_config,\n    num_trials=total_trials,\n)"
      ]
    },
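    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For the initial phase, the generation strategy chosen by Ax typically draws quasi-random Sobol points, which cover the search space more evenly than independent random draws. As an illustration only, the sketch below uses SciPy's ``scipy.stats.qmc.Sobol`` (assuming SciPy 1.7 or later is installed) to draw 8 points for the five range parameters defined above. Unlike Ax, it works on the raw rather than the log scale, does not round integer parameters, and leaves out the ``batch_size`` choice parameter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom scipy.stats import qmc\n\n# Bounds of the five range parameters in the search space above\nsobol_names = [\"hidden_size_1\", \"hidden_size_2\", \"learning_rate\", \"epochs\", \"dropout\"]\nsobol_lower = [16, 16, 1e-4, 1, 0.0]\nsobol_upper = [128, 128, 1e-2, 4, 0.5]\n\nsampler = qmc.Sobol(d=len(sobol_names), scramble=True)\nunit_points = sampler.random_base2(m=3)  # 2**3 = 8 points in the unit cube\nsobol_points = qmc.scale(unit_points, sobol_lower, sobol_upper)\n\nfor row in sobol_points:\n    print({name: round(float(value), 4) for name, value in zip(sobol_names, row)})"
      ]
    },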
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Configuring the Scheduler\n\nThe ``Scheduler`` acts as the loop control for the optimization.\nIt communicates with the backend to launch trials, check their status,\nand retrieve results. In the case of this tutorial, it is simply reading\nand parsing the locally saved logs. In a remote execution setting,\nit would instead call the APIs of the remote backend. The following illustration from the Ax\n[Scheduler tutorial](https://ax.dev/tutorials/scheduler.html)\nsummarizes how the Scheduler interacts with external systems used to run\ntrial evaluations:\n\n<img src=\"../../_static/img/ax_scheduler_illustration.png\">\n\n\nThe ``Scheduler`` requires the ``Experiment`` and the ``GenerationStrategy``.\nA set of options can be passed in via ``SchedulerOptions``. Here, we\nconfigure the total number of evaluations as well as ``max_pending_trials``,\nthe maximum number of trials that should run concurrently. In our\nlocal setting, this is the number of training jobs running as individual\nprocesses, while in a remote execution setting, this would be the number\nof machines you want to use in parallel.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ax.service.scheduler import Scheduler, SchedulerOptions\n\nscheduler = Scheduler(\n    experiment=experiment,\n    generation_strategy=gs,\n    options=SchedulerOptions(\n        total_trials=total_trials, max_pending_trials=4\n    ),\n)"
      ]
    },
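    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "``max_pending_trials=4`` caps how many trials are in flight at once; as soon as one finishes, another can be launched. The toy loop below imitates that pattern with ``concurrent.futures`` from the standard library: ``fake_train`` is a hypothetical stand-in for a training job, and the loop never has more than four jobs pending. It only illustrates the scheduling idea and says nothing about how the Ax ``Scheduler`` is implemented."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import random\nimport time\nfrom concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait\n\nMAX_PENDING = 4\nTOTAL_TRIALS = 12\n\n\ndef fake_train(trial_index):\n    # Hypothetical stand-in for a training job of random duration\n    time.sleep(random.uniform(0.01, 0.05))\n    return trial_index\n\n\ncompleted = []\npeak_pending = 0\nwith ThreadPoolExecutor(max_workers=MAX_PENDING) as pool:\n    pending = set()\n    next_trial = 0\n    while len(completed) < TOTAL_TRIALS:\n        # Launch new trials while below the concurrency limit\n        while next_trial < TOTAL_TRIALS and len(pending) < MAX_PENDING:\n            pending.add(pool.submit(fake_train, next_trial))\n            next_trial += 1\n        peak_pending = max(peak_pending, len(pending))\n        # Block until at least one pending trial finishes, then record the finished ones\n        done, pending = wait(pending, return_when=FIRST_COMPLETED)\n        completed.extend(future.result() for future in done)\n\nprint(f\"completed {len(completed)} trials, at most {peak_pending} pending at once\")"
      ]
    },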
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Running the optimization\n\nNow that everything is configured, we can let Ax run the optimization\nin a fully automated fashion. The Scheduler will periodically check\nthe logs for the status of all currently running trials, and if a\ntrial completes, the Scheduler will update its status on the\nexperiment and fetch the observations needed for the Bayesian\noptimization algorithm.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "scheduler.run_all_trials()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Evaluating the results\n\nWe can now inspect the result of the optimization using helper\nfunctions and visualizations included with Ax.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, we generate a dataframe with a summary of the results\nof the experiment. Each row in this dataframe corresponds to a\ntrial (that is, a training job that was run), and contains information\non the status of the trial, the parameter configuration that was\nevaluated, and the metric values that were observed. This provides\nan easy way to sanity check the optimization.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ax.service.utils.report_utils import exp_to_df\n\ndf = exp_to_df(experiment)\ndf.head(10)"
      ]
    },
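    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since ``exp_to_df`` returns a regular pandas DataFrame, standard pandas operations apply. For instance, the sketch below keeps only the trials whose outcomes satisfy both objective thresholds and sorts them by model size. To stay self-contained, it runs on a small made-up DataFrame and assumes that the metric columns are named after the metrics (``val_acc`` and ``num_params``); check this against the columns of the actual ``df`` before reusing it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n\n# Made-up rows standing in for the output of exp_to_df(experiment)\ntoy_df = pd.DataFrame(\n    {\n        \"trial_index\": [0, 1, 2, 3],\n        \"val_acc\": [0.930, 0.951, 0.968, 0.975],\n        \"num_params\": [25_000, 41_000, 76_000, 101_000],\n    }\n)\n\n# Keep the trials inside the region bounded by the objective thresholds\nfeasible = toy_df[(toy_df[\"val_acc\"] >= 0.94) & (toy_df[\"num_params\"] <= 80_000)]\nfeasible.sort_values(\"num_params\")"
      ]
    },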
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can also visualize the Pareto frontier of tradeoffs between the\nvalidation accuracy and the number of model parameters.\n\n.. tip::\n  Ax uses Plotly to produce interactive plots, which allow you to\n  do things like zoom, crop, or hover in order to view details\n  of components of the plot. Try it out, and take a look at the\n  [visualization tutorial](https://ax.dev/tutorials/visualizations.html)\n  if you'd like to learn more.\n\nThe final optimization results are shown in the figure below, where\nthe color corresponds to the iteration number for each trial.\nWe see that our method was able to successfully explore the\ntradeoffs and found both large models with high validation\naccuracy and small models with comparatively lower\nvalidation accuracy.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ax.service.utils.report_utils import _pareto_frontier_scatter_2d_plotly\n\n_pareto_frontier_scatter_2d_plotly(experiment)"
      ]
    },
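    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Pareto frontier consists of the non-dominated trials: no other trial is at least as accurate and at least as small while being strictly better in one of the two. The plain-Python sketch below computes such a frontier for two objectives by sweeping from the smallest to the largest model. ``pareto_frontier`` is a hypothetical helper and the outcome values are made up; it illustrates the idea rather than how the Ax plotting helper computes its frontier."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def pareto_frontier(points):\n    # Hypothetical helper; points are (num_params, val_acc) pairs.\n    # Minimize num_params, maximize val_acc.\n    frontier = []\n    best_acc = float(\"-inf\")\n    # Sweep from the smallest to the largest model; ties go to higher accuracy first\n    for num_params, val_acc in sorted(points, key=lambda p: (p[0], -p[1])):\n        # A larger model only joins the frontier if it is strictly more accurate\n        if val_acc > best_acc:\n            frontier.append((num_params, val_acc))\n            best_acc = val_acc\n    return frontier\n\n\n# Made-up (num_params, val_acc) outcomes, for illustration only\noutcomes = [(20_000, 0.940), (35_000, 0.955), (40_000, 0.950), (70_000, 0.968), (90_000, 0.966)]\nprint(pareto_frontier(outcomes))"
      ]
    },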
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To better understand what our surrogate models have learned about\nthe black-box objectives, we can take a look at the leave-one-out\ncross-validation results. Since our models are Gaussian Processes,\nthey not only provide point predictions but also uncertainty estimates\nabout these predictions. For a good model, the predicted means\n(the points in the figure) are close to the 45-degree line and the\nconfidence intervals cover the 45-degree line with the expected frequency\n(here we use 95% confidence intervals, so we would expect them to contain\nthe true observation 95% of the time).\n\nAs the figures below show, the model size (``num_params``) metric is\nmuch easier to model than the validation accuracy (``val_acc``) metric.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate\nfrom ax.plot.diagnostic import interact_cross_validation_plotly\n\ncv = cross_validate(model=gs.model)  # The surrogate model is stored on the GenerationStrategy\ncompute_diagnostics(cv)\n\ninteract_cross_validation_plotly(cv)"
      ]
    },
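    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The coverage check described above can be made concrete: count how often the observed value falls inside the predicted 95% interval, which for a Gaussian prediction is the mean plus or minus 1.96 standard errors. The sketch below does this for a handful of made-up leave-one-out predictions. With Ax, you would take the observed values, predicted means, and standard errors from the ``cv`` results instead; their exact layout depends on your Ax version, so the sketch does not rely on it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Made-up leave-one-out results: (observed, predicted_mean, predicted_sem)\nloo_results = [\n    (0.951, 0.948, 0.004),\n    (0.962, 0.955, 0.003),\n    (0.940, 0.946, 0.005),\n    (0.970, 0.966, 0.002),\n]\n\nZ_95 = 1.96  # 97.5% quantile of the standard normal, giving a two-sided 95% interval\n\n\ndef covered(observed, mean, sem, z=Z_95):\n    # True if the observation lies inside the interval mean +/- z * sem\n    return abs(observed - mean) <= z * sem\n\n\ncoverage = sum(covered(*row) for row in loo_results) / len(loo_results)\nprint(f\"empirical coverage: {coverage:.2f} (ideally close to 0.95)\")"
      ]
    },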
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can also make contour plots to better understand how the different\nobjectives depend on two of the input parameters. In the figure below,\nwe show the validation accuracy predicted by the model as a function\nof the two hidden sizes. The validation accuracy clearly increases\nas the hidden sizes increase.\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from ax.plot.contour import interact_contour_plotly\n\ninteract_contour_plotly(model=gs.model, metric_name=\"val_acc\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Similarly, we show the number of model parameters as a function of\nthe hidden sizes in the figure below and see that it also increases\nwith both hidden sizes (the dependency on ``hidden_size_1``\nis much larger).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "interact_contour_plotly(model=gs.model, metric_name=\"num_params\")"
      ]
    },
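    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The stronger dependence on ``hidden_size_1`` follows from the shape of the network. Assuming that ``mnist_train_nas.py`` builds a fully connected network with layer widths 784, ``hidden_size_1``, ``hidden_size_2``, and 10 (an assumption; check the script), each additional unit in the first hidden layer adds 785 + ``hidden_size_2`` parameters (weights from the 784 input pixels, one bias, and weights into the second hidden layer), whereas each additional unit in the second hidden layer adds only ``hidden_size_1`` + 11. The hypothetical helper ``mlp_num_params`` below computes the count under that assumption."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def mlp_num_params(hidden_size_1, hidden_size_2, n_inputs=28 * 28, n_classes=10):\n    # Weights plus biases of three fully connected layers (assumed architecture)\n    layer_1 = n_inputs * hidden_size_1 + hidden_size_1\n    layer_2 = hidden_size_1 * hidden_size_2 + hidden_size_2\n    layer_3 = hidden_size_2 * n_classes + n_classes\n    return layer_1 + layer_2 + layer_3\n\n\nbase = mlp_num_params(64, 64)\nprint(\"base:\", base)\n# One extra unit in the first hidden layer adds 785 + hidden_size_2 parameters\nprint(\"hidden_size_1 + 1:\", mlp_num_params(65, 64) - base)\n# One extra unit in the second hidden layer adds hidden_size_1 + 11 parameters\nprint(\"hidden_size_2 + 1:\", mlp_num_params(64, 65) - base)"
      ]
    },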
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Acknowledgements\n\nWe thank the TorchX team (in particular Kiuk Chung and Tristan Rice)\nfor their help with integrating TorchX with Ax.\n\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}